import torch
import torch.nn as nn
from torch import FloatTensor
from torch.nn.parameter import Parameter
from scipy.spatial.distance import pdist, squareform
import torch.nn.functional as F
import numpy as np
import math
import matplotlib.colors as mcolors
import os 

class DistanceAdj(nn.Module):
    """Learnable prior over pairwise temporal distances.

    For positions ``i, j`` in ``range(seq_len)`` the prior is
    ``exp(-| w * |i - j|**2 - b |)``, where ``w`` and ``b`` are learnable
    single-element parameters initialised from ``sigma`` and ``bias``.
    The resulting ``(seq_len, seq_len)`` matrix is repeated along a new
    leading batch dimension.
    """

    def __init__(self, sigma, bias):
        super(DistanceAdj, self).__init__()
        # Learnable scale (w) and offset (b) of the distance prior.
        self.w = nn.Parameter(torch.FloatTensor(1))
        self.b = nn.Parameter(torch.FloatTensor(1))
        self.w.data.fill_(sigma)
        self.b.data.fill_(bias)

    def forward(self, batch_size, seq_len):
        """Return the distance prior repeated ``batch_size`` times.

        Args:
            batch_size: number of copies stacked along dimension 0.
            seq_len: number of temporal positions.

        Returns:
            Tensor of shape ``(batch_size, seq_len, seq_len)`` on the same
            device as the module's parameters.
        """
        # |i - j| for every pair of positions (same float32 values as the
        # former scipy ``pdist(..., 'cityblock')`` + ``squareform``), created
        # directly on the parameters' device instead of a hard-coded
        # ``.cuda()`` so the module works on CPU and on any GPU it was moved to.
        pos = torch.arange(seq_len, dtype=torch.float32, device=self.w.device)
        dist = (pos.unsqueeze(1) - pos.unsqueeze(0)).abs()
        dist = torch.exp(-torch.abs(self.w * dist ** 2 - self.b))
        # ``repeat`` (not ``expand``) so each batch entry owns its memory,
        # matching the previous behaviour for callers that modify it in place.
        return dist.unsqueeze(0).repeat(batch_size, 1, 1)


class TCA(nn.Module):
    """Attention block with a global branch and a masked ("local") branch.

    Both branches share the same Q/K/V projections; the local branch sets
    logits to -1e9 wherever ``mask == 0`` before the softmax. The branch
    outputs are blended with a scalar ``alpha``, optionally power- and
    L2-normalised, and projected back to ``d_model``.
    """

    def __init__(self, d_model, dim_k, dim_v, n_heads, norm=None):
        super(TCA, self).__init__()
        self.dim_v = dim_v
        self.dim_k = dim_k  # same as dim_q
        self.n_heads = n_heads
        self.norm = norm

        self.q = nn.Linear(d_model, dim_k)
        self.k = nn.Linear(d_model, dim_k)
        self.v = nn.Linear(d_model, dim_v)
        self.o = nn.Linear(dim_v, d_model)

        self.norm_fact = 1 / math.sqrt(dim_k)
        # Learnable blend gate; kept so existing checkpoints still load, but
        # forward() currently overrides it (see the NOTE there).
        self.alpha = nn.Parameter(torch.tensor(0.))
        self.act = nn.Softmax(dim=-1)

    def forward(self, x, mask, adj=None, scale=None):
        """Run the attention block.

        Args:
            x: input features; the reshapes below assume ``(B, T, d_model)``
                (TODO confirm with callers).
            mask: tensor broadcastable to the attention logits; entries equal
                to 0 are excluded from the local branch.
            adj: optional additive bias on the attention logits.
            scale: unused; kept for interface compatibility.

        Returns:
            ``(out, idx, idx1)``: ``out`` is the projected output viewed as
            ``(-1, T, d_model)``; ``idx`` / ``idx1`` are 0-d index tensors
            from ``argmax`` locating the most-attended key position near the
            start / end of the sequence.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        # NOTE(review): this view reinterprets the contiguous (B, T, dim)
        # projection as (-1, B, T, dim // n_heads); for n_heads > 1 it does
        # not split each token's features into heads the usual way. Kept
        # unchanged for compatibility with trained weights.
        Q = self.q(x).view(-1, batch, seq_len, self.dim_k // self.n_heads)
        K = self.k(x).view(-1, batch, seq_len, self.dim_k // self.n_heads)
        V = self.v(x).view(-1, batch, seq_len, self.dim_v // self.n_heads)

        g_map = torch.matmul(Q, K.permute(0, 1, 3, 2)) * self.norm_fact
        if adj is not None:
            g_map = g_map + adj
        # Out-of-place fill leaves g_map untouched for the global branch.
        l_map = g_map.masked_fill(mask.data.eq(0), -1e9)

        g_map = self.act(g_map)
        l_map = self.act(l_map)
        glb = torch.matmul(g_map, V).view(batch, seq_len, -1)
        lcl = torch.matmul(l_map, V).view(batch, seq_len, -1)

        # NOTE: the learned gate torch.sigmoid(self.alpha) is deliberately
        # bypassed, so only the local branch contributes. Restore
        # ``alpha = torch.sigmoid(self.alpha)`` to re-enable the blend.
        alpha = 0
        tmp = alpha * glb + (1 - alpha) * lcl
        if self.norm:
            tmp = torch.sqrt(F.relu(tmp)) - torch.sqrt(F.relu(-tmp))  # power norm
            tmp = F.normalize(tmp)  # l2 norm
        tmp = self.o(tmp).view(-1, seq_len, x.shape[2])

        idx, idx1 = self._peak_key_indices(l_map)
        return tmp, idx, idx1

    @staticmethod
    def _peak_key_indices(l_map):
        """Return argmax key positions of local attention near both ends.

        Scores are the sums of ``l_map`` over dim 2 (queries) for element
        ``[0, 0]`` of the leading two dims. ``idx`` is the argmax within the
        first ``T // 4 + 1`` positions; ``idx1`` is the argmax within the
        trailing window ``[-T // 10 - 1:]`` expressed as an absolute position.
        """
        seq_len = l_map.shape[2]
        scores = torch.sum(l_map, dim=2)[0, 0]
        idx = torch.argmax(scores[:seq_len // 4 + 1])
        tail = scores[-seq_len // 10 - 1:]
        # Offset by the window's real start. The former constant
        # ``T - T // 10 - 2`` was one short whenever T is a multiple of 10
        # and negative for T == 1.
        idx1 = torch.argmax(tail) + (scores.shape[0] - tail.shape[0])
        return idx, idx1

import matplotlib.pyplot as plt

def _next_png_index(folder_path):
    """Return 1 + the largest integer stem of the ``.png`` files in ``folder_path``.

    Files whose name before the first '.' is not an integer are ignored; a
    folder with no numbered PNGs yields 1.
    """
    max_number = 0
    for filename in os.listdir(folder_path):
        if not filename.endswith('.png'):
            continue
        try:
            max_number = max(max_number, int(filename.split('.')[0]))
        except ValueError:
            continue  # not a numbered figure; skip instead of crashing
    return max_number + 1


def plot_map(x, save_dir='analyse_tools_ucf/figs/MM_Map', show=True):
    """Render ``x`` as an image, save it as the next numbered PNG, then show it.

    Args:
        x: 2-D array-like passed to ``plt.imshow``.
        save_dir: directory holding the numbered PNGs; created if missing.
            Defaults to the previously hard-coded location.
        show: whether to call ``plt.show()`` after saving.
    """
    fig = plt.figure()
    plt.imshow(x, cmap='cividis', interpolation='nearest')
    plt.axis('off')

    os.makedirs(save_dir, exist_ok=True)
    path = f'{save_dir}/{_next_png_index(save_dir)}.png'
    # Save before show(): show() may block and tear the figure down, after
    # which savefig would write an empty image.
    plt.savefig(path, dpi=500)
    print(path)
    if show:
        plt.show()
    plt.close(fig)  # avoid accumulating open figures across calls